package com.novel.framework.webSocket;

import com.novel.common.utils.StringUtils;
import lombok.Data;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.Collection;
import java.util.Hashtable;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * WebSocket endpoint mounted at {@code /websocket/{userId}/{id}} that tracks open connections
 * per user.
 * <p>
 * {@code userId} groups connections by user and {@code id} distinguishes several connections of
 * the same user (NOTE(review): presumably one per browser tab/client — confirm with callers).
 * Connection state lives in static collections because each connection is handled by its own
 * endpoint instance, while other beans call {@link #sendAllMessage(String)} and
 * {@link #sendOneMessage(String, String, String)} on the Spring-managed instance.
 *
 * @author novel
 * @date 2019/5/22
 */
@ServerEndpoint(value = "/websocket/{userId}/{id}")
@Component
@Data
public class MyWebsocket {

    private static final Logger log = Logger.getLogger(MyWebsocket.class.getName());

    /** Session of the connection handled by this endpoint instance; set in {@link #onOpen}. */
    private Session session;

    /** Every endpoint instance with an open connection; used for broadcasting. */
    private static final CopyOnWriteArraySet<MyWebsocket> webSockets = new CopyOnWriteArraySet<>();

    /**
     * Session pool: userId -> (connection id -> session).
     * <p>
     * A user's inner map is created on their first connection and removed once their last
     * connection is gone, so the pool does not accumulate empty entries.
     */
    private static final Map<String, Map<String, Session>> sessionPool = new ConcurrentHashMap<>();

    /**
     * Called when a connection has been established: registers the session for broadcasting and
     * in the per-user session pool.
     *
     * @param session the newly opened session
     * @param userId  user the connection belongs to (path parameter)
     * @param id      connection id within that user (path parameter)
     */
    @OnOpen
    public void onOpen(Session session, @PathParam(value = "userId") String userId, @PathParam(value = "id") String id) {
        this.session = session;
        webSockets.add(this);
        // Create-or-update atomically: a concurrent close of the user's last connection cannot
        // remove the inner map between looking it up and inserting into it.
        sessionPool.compute(userId, (key, sessions) -> {
            Map<String, Session> target = sessions != null ? sessions : new ConcurrentHashMap<>();
            target.put(id, session);
            return target;
        });
        log.info("【websocket消息】有新的连接，总数为:" + webSockets.size());
    }

    /**
     * Called when the connection is closed: unregisters this endpoint's session.
     *
     * @param userId user the connection belongs to (path parameter)
     * @param id     connection id within that user (path parameter)
     */
    @OnClose
    public void onClose(@PathParam(value = "userId") String userId, @PathParam(value = "id") String id) {
        removeConnection(userId, id);
        log.info("有一连接关闭");
    }

    /**
     * Called when a text message is received from the client; currently only logs it.
     *
     * @param message the message sent by the client
     * @param session the session the message arrived on
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        log.info("来自客户端的消息:" + message);
    }

    /**
     * Called when an error occurs on the connection: logs it and unregisters this endpoint's
     * session.
     *
     * @param session the session the error occurred on
     * @param userId  user the connection belongs to (path parameter)
     * @param id      connection id within that user (path parameter)
     * @param error   the error raised by the container
     */
    @OnError
    public void onError(Session session, @PathParam(value = "userId") String userId, @PathParam(value = "id") String id, Throwable error) {
        log.log(Level.SEVERE, "发生错误", error);
        removeConnection(userId, id);
    }

    /**
     * Broadcasts {@code message} to every open connection. A failure on one connection does not
     * prevent delivery to the others.
     *
     * @param message text to send
     */
    public void sendAllMessage(String message) {
        // Log once per broadcast, not once per recipient.
        log.info("【websocket消息】广播消息:" + message);
        for (MyWebsocket webSocket : webSockets) {
            sendText(webSocket.session, message);
        }
    }

    /**
     * Sends {@code message} to a single user's connection(s).
     *
     * @param userId  target user; nothing is sent if the user has no open connections
     * @param id      target connection id; if empty, the message goes to all of the user's
     *                connections
     * @param message text to send
     */
    public void sendOneMessage(String userId, String id, String message) {
        Map<String, Session> sessionMap = sessionPool.get(userId);
        if (sessionMap == null) {
            return;
        }
        if (StringUtils.isEmpty(id)) {
            // Send per session so one broken connection does not stop the remaining sends.
            sessionMap.values().forEach(target -> sendText(target, message));
        } else {
            sendText(sessionMap.get(id), message);
        }
    }

    /**
     * Removes this endpoint's connection from the broadcast set and the session pool.
     * <p>
     * The pool entry is removed only if it still maps to this instance's session, so a newer
     * connection that reused the same {@code userId}/{@code id} is not evicted when the old one
     * closes. The user's entry is dropped once it has no connections left.
     */
    private void removeConnection(String userId, String id) {
        webSockets.remove(this);
        Session current = this.session;
        sessionPool.computeIfPresent(userId, (key, sessions) -> {
            if (current != null) {
                sessions.remove(id, current);
            }
            // Returning null removes the user's entry from the pool.
            return sessions.isEmpty() ? null : sessions;
        });
    }

    /**
     * Sends {@code message} asynchronously to one session. Missing or closed sessions are
     * skipped; exceptions are logged rather than propagated so callers iterating over many
     * sessions keep going.
     */
    private static void sendText(Session target, String message) {
        if (target == null || !target.isOpen()) {
            return;
        }
        try {
            target.getAsyncRemote().sendText(message);
        } catch (Exception e) {
            log.log(Level.WARNING, "【websocket消息】发送失败", e);
        }
    }
}
